/*========================== begin_copyright_notice ============================

Copyright (C) 2017-2024 Intel Corporation

SPDX-License-Identifier: MIT

============================= end_copyright_notice ===========================*/

//===----------------------------------------------------------------------===//
//
/// CMImpParam
/// ----------
///
/// As well as explicit kernel args declared in the CM kernel function, certain
/// implicit args are also passed. These fall into 3 categories:
///
/// 1. fields set up in r0 by the hardware, depending on which dispatch method
///    is being used (e.g. media walker);
///
/// 2. implicit args set up along with the explicit args in CURBE by the CM
///    runtime.
///
/// 3. implicit OCL/L0 args set up, e.g. private base, byval arg linearization.
///
/// The r0 implicit args are represented in LLVM IR by special intrinsics, and
/// the GenX backend generates these to special reserved vISA registers.
///
/// For the CM runtime implicit args in (2) above, in vISA 3.2 and earlier,
/// these were also represented by LLVM special intrinsics and vISA special
/// reserved vISA registers. Because they are specific to the CM runtime, and
/// not any other user of vISA, vISA 3.3 has removed them, and instead they are
/// handled much like other kernel args in the input table.
///
/// The *kind* byte in the input table has two fields:
///
/// * the *category* field, saying whether the input is general/surface/etc;
///
/// * the *provenance* field, saying whether the input is an explicit one from
///   the CM source, or an implicit one generated by this pass. This is a
///   protocol agreed between the CM compiler (in fact this pass) and the CM
///   runtime.
///
/// Within the CM compiler, the vISA input table for a kernel is represented by
/// an array of kind bytes, each one corresponding to an argument of the kernel
/// function.
///
/// Clang codegen still generates special intrinsics for these CM runtime
/// implicit args. It is the job of this CMImpParam pass to transform those
/// intrinsics:
///
/// * where the intrinsic for a CM runtime implicit arg is used somewhere:
///
///   - a global variable is created for it;
///
///   - for any kernel that uses the implicit arg (or can reach a subroutine
///     that uses it), the implicit arg is added to the input table in the
///     kernel metadata and as an extra arg to the definition of the kernel
///     itself, and its value is stored into the global variable;
///
///   - for any fixed signature function (implicit args cannot be passed as an
///     additional function parameter) the implicit arg is loaded from the
///     implicit args buffer (which is always available during execution) and
///     then stored into the corresponding global variable;
///
///   - kernels that require implicit args buffer being allocated are marked;
///
/// * each use of the intrinsic for a CM runtime implicit arg is transformed
///   into a load of the corresponding global variable.
///
/// Like any other global variable, the subsequent CMABI pass turns the global
/// variable for an implicit arg into local variable(s) passed into subroutines
/// if necessary.
///
/// This pass also linearizes kernel byval arguments.
/// If a kernel has an input pointer argument with byval attribute, it means
/// that it will be passed as a value with the argument's size = sizeof(the
/// type), not sizeof(the type *). To support such kinds of arguments, VC (as
/// well as scalar IGC) makes implicit linearization, e.g.
///
///   %struct.s1 = type { [2 x i32], i8 } ===> i32, i32, i8
///
/// This implicit linearization is added as kernel arguments and mapped via
/// metadata to the original explicit byval argument.
///
///   %struct.s1 = type { [2 x i32], i8 }
///
///   declare i32 @foo(%struct.s1* byval(%struct.s1) "VCArgumentDesc"="svmptr_t"
///                    "VCArgumentIOKind"="0" "VCArgumentKind"="0" %arg, i64
///                    %arg1);
///
/// It will be transformed into the following (uses of the byval args are
/// changed later, in CMKernelArgOffset):
///
///   declare i32 @foo(%struct.s1* byval(%struct.s1) "VCArgumentDesc"="svmptr_t"
///                     "VCArgumentIOKind"="0" "VCArgumentKind"="0" %arg, i64
///                     %arg1, i32 %__arg_lin__arg_0, i32 %__arg_lin__arg_1, i8
///                     %__arg_lin__arg_2);
///
/// Additionally, information about this implicit linearization will be written
/// to kernel metadata as internal::KernelMDOp::LinearizationArgs. It stores the
/// mapping between each explicit byval argument and its linearization.
///
//===----------------------------------------------------------------------===//

#include "vc/GenXOpts/GenXOpts.h"
#include "vc/Utils/GenX/KernelInfo.h"

#include "vc/InternalIntrinsics/InternalIntrinsics.h"
#include "vc/Support/GenXDiagnostic.h"
#include "vc/Utils/GenX/IRBuilder.h"
#include "vc/Utils/GenX/ImplicitArgsBuffer.h"
#include "vc/Utils/GenX/IntrinsicsWrapper.h"
#include "vc/Utils/GenX/PredefinedVariable.h"
#include "vc/Utils/General/DebugInfo.h"
#include "vc/Utils/General/FunctionAttrs.h"
#include "vc/Utils/General/IRBuilder.h"

#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/GenXIntrinsics/GenXIntrinsics.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <numeric>
#include <optional>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "Probe/Assertion.h"
#include "llvmWrapper/Analysis/CallGraph.h"
#include "llvmWrapper/IR/DerivedTypes.h"
#include "llvmWrapper/IR/Function.h"
#include "llvmWrapper/Support/Alignment.h"
#include <llvmWrapper/ADT/Optional.h>

#define DEBUG_TYPE "CMImpParam"

using namespace llvm;

// Hidden command line switch. Its value becomes CMImpParam::HasPayloadInMemory
// for pass instances constructed without an explicit payload setting.
static cl::opt<bool>
    PayloadInMemoryOpt("cmimpparam-payload-in-memory",
                       cl::desc("Whether the target has payload in memory"),
                       cl::init(true), cl::Hidden);

// A sequence whose full contents may be impossible to determine, e.g. the list
// of called functions when indirect call instructions are present. The three
// possible states are:
//   * std::nullopt       - the full list cannot be defined, but it is known
//                          not to be empty;
//   * an empty vector    - the full list is known and it is empty;
//   * a non-empty vector - the full list is known and is exactly this vector.
template <typename T> using MaybeUndefSeq = std::optional<std::vector<T>>;
// MaybeUndefSeq specialized for functions.
using MaybeUndefFuncSeq = MaybeUndefSeq<Function *>;
// Non-null, rebindable handle to a function that can be stored in containers.
using FunctionRef = std::reference_wrapper<Function>;

// Checks whether all elements of the provided vector \p V are distinct.
// T must be hashable with std::hash and comparable with operator==.
// Stops scanning as soon as the first duplicate is found.
template <typename T> static bool isUnique(const std::vector<T> &V) {
  std::unordered_set<T> Seen;
  Seen.reserve(V.size());
  // insert().second is false exactly when the element was already present.
  return std::all_of(V.begin(), V.end(), [&Seen](const T &Elt) {
    return Seen.insert(Elt).second;
  });
}

namespace {

// Temporary description of one element produced when a byval aggregate
// kernel argument is linearized into implicit arguments.
struct LinearizationElt {
  // Type of the linearized element.
  Type *Ty;
  // Offset of the element inside the aggregate (presumably in bytes - TODO
  // confirm against LinearizeAggregateType).
  unsigned Offset;
};
// Full linearization of one aggregate: the list of its elements.
using LinearizedTy = std::vector<LinearizationElt>;
// Maps an explicit byval argument to its linearization.
using ArgLinearization = std::unordered_map<Argument *, LinearizedTy>;
// Sequence of calls to implicit-arg intrinsics.
using ImplArgIntrSeq = std::vector<CallInst *>;
// Set of implicit args, each denoted by a real or pseudo intrinsic ID.
using IntrIDSet = std::set<unsigned>;
// Maps a function to the set of implicit args associated with it.
using IntrIDMap = std::unordered_map<Function *, IntrIDSet>;

// Implicit args in this pass are denoted by the corresponding intrinsic ID.
// Some implicit args have no intrinsic of their own, so pseudo intrinsic IDs
// are provided for them. Pseudo ID values start at
// vc::InternalIntrinsic::not_any_intrinsic and are guaranteed not to overlap
// with real intrinsic IDs. [First, Last) is the range of all pseudo IDs.
namespace PseudoIntrinsic {
enum Enum : unsigned {
  PrivateBase = vc::InternalIntrinsic::not_any_intrinsic,
  First = PrivateBase,
  ImplicitArgsBuffer,
  Last
};
} // namespace PseudoIntrinsic

// Module pass that lowers implicit-arg intrinsics into loads of per-argument
// global variables, adds the required implicit args to kernels, and makes
// fixed signature functions obtain their implicit args from the implicit args
// buffer.
struct CMImpParam : public ModulePass {
  static char ID;
  // Defines whether payload is in memory or on registers. It depends on target
  // architecture.
  bool HasPayloadInMemory = false;
  const DataLayout *DL = nullptr;

#if LLVM_VERSION_MAJOR >= 16
  CallGraph &CG;
  CMImpParam(CallGraph &CG, bool HasPayloadInMemoryIn)
      : ModulePass{ID}, HasPayloadInMemory{HasPayloadInMemoryIn}, CG(CG) {
    initializeCMImpParamPass(*PassRegistry::getPassRegistry());
  }

  CMImpParam(CallGraph &CG)
      : ModulePass{ID}, HasPayloadInMemory{PayloadInMemoryOpt}, CG(CG) {
    initializeCMImpParamPass(*PassRegistry::getPassRegistry());
  }
#else  // LLVM_VERSION_MAJOR
  CMImpParam(bool HasPayloadInMemoryIn)
      : ModulePass{ID}, HasPayloadInMemory{HasPayloadInMemoryIn} {
    initializeCMImpParamPass(*PassRegistry::getPassRegistry());
  }

  CMImpParam() : ModulePass{ID}, HasPayloadInMemory{PayloadInMemoryOpt} {
    initializeCMImpParamPass(*PassRegistry::getPassRegistry());
  }
#endif // LLVM_VERSION_MAJOR

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<CallGraphWrapperPass>();
  }

  StringRef getPassName() const override { return "CM Implicit Params"; }

  bool runOnModule(Module &M) override;

private:
  void replaceWithGlobal(CallInst *CI);
  void replaceImplicitArgIntrinsics(const ImplArgIntrSeq &Workload);

  LinearizedTy LinearizeAggregateType(Type *AggrTy);
  ArgLinearization GenerateArgsLinearizationInfo(Function &F);

  void
  processKernels(const std::vector<FunctionRef> &Kernels,
                 const IntrIDMap &UsedIntrInfo,
                 const std::unordered_set<Function *> &RequireImplArgsBuffer);
  CallGraphNode *processKernelParameters(Function *F,
                                         const IntrIDSet &UsedImplicits);

  std::pair<IntrIDMap, std::unordered_set<Function *>>
  analyzeFixedSignatureFunctions(Module &M, const IntrIDMap &UsedIntrInfo);
  void processFixedSignatureFunction(Function &F,
                                     const IntrIDSet &UsedIntrInfo);
  void storeImplArgInFixedSignatureFunction(unsigned IID,
                                            Value &ImplArgsBufferPtr,
                                            IRBuilder<> &IRB);

  vc::ThreadPayloadKind getThreadPayloadKind() const {
    if (HasPayloadInMemory)
      return vc::ThreadPayloadKind::InMemory;
    return vc::ThreadPayloadKind::OnRegister;
  }

  // Returns the value wrapped by \p M, or nullptr if \p M does not wrap a
  // value.
  static Value *getValue(Metadata *M) {
    if (auto *VM = dyn_cast<ValueAsMetadata>(M))
      return VM->getValue();
    return nullptr;
  }

  // Convert to implicit thread payload related intrinsics.
  void ConvertToOCLPayload(Module &M);

  // Returns the kernel input table kind (category | provenance) for the
  // implicit arg denoted by \p IID. IDs that are not implicit args map to
  // AK_NORMAL.
  uint32_t MapToKind(unsigned IID) {
    using namespace vc;
    switch (IID) {
    default:
      return KernelMetadata::AK_NORMAL;
    case InternalIntrinsic::assert_buffer:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_OCL_ASSERT_BUFFER;
    case InternalIntrinsic::print_buffer:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_OCL_PRINTF_BUFFER;
    case InternalIntrinsic::sync_buffer:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_OCL_SYNC_BUFFER;
    case GenXIntrinsic::genx_local_size:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_LOCAL_SIZE;
    case GenXIntrinsic::genx_local_id:
    case GenXIntrinsic::genx_local_id16:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_LOCAL_ID;
    case GenXIntrinsic::genx_group_count:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_GROUP_COUNT;
    case GenXIntrinsic::genx_get_scoreboard_deltas:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_SB_DELTAS;
    case GenXIntrinsic::genx_get_scoreboard_bti:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_SURFACE) |
             KernelMetadata::IMP_SB_BTI;
    case GenXIntrinsic::genx_get_scoreboard_depcnt:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_SURFACE) |
             KernelMetadata::IMP_SB_DEPCNT;
    case PseudoIntrinsic::PrivateBase:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_OCL_PRIVATE_BASE;
    case PseudoIntrinsic::ImplicitArgsBuffer:
      return static_cast<KernelMetadata::ImpValue>(KernelMetadata::AK_NORMAL) |
             KernelMetadata::IMP_IMPL_ARGS_BUFFER;
    }
    // Every switch branch (including default) returns, so nothing follows.
  }

  // Returns the internal global variable that holds the implicit arg denoted
  // by \p IID. The variable is created in \p F's module on the first request
  // and cached in GlobalsMap afterwards.
  GlobalVariable *getOrCreateGlobalForIID(Function *F, unsigned IID) {
    // A single lookup instead of count() followed by operator[].
    auto It = GlobalsMap.find(IID);
    if (It != GlobalsMap.end())
      return It->second;

    Type *Ty = getIntrinRetType(F->getContext(), IID);
    IGC_ASSERT(Ty);

    auto IntrinsicName = vc::getAnyName(IID, ArrayRef<Type *>());
    GlobalVariable *NewVar = new GlobalVariable(
        *F->getParent(), Ty, false, GlobalVariable::InternalLinkage,
        UndefValue::get(Ty), "__imparg_" + IntrinsicName);
    GlobalsMap[IID] = NewVar;

    addDebugInfoForImplicitGlobal(*NewVar, IntrinsicName);

    return NewVar;
  }

  // Attaches a global variable debug-info expression to \p Var, named after
  // \p Name with a "__" prefix and '.' replaced by '_'. Does nothing if the
  // module has no debug info.
  static void addDebugInfoForImplicitGlobal(GlobalVariable &Var,
                                            StringRef Name) {
    auto &M = *Var.getParent();
    if (!vc::DIBuilder::checkIfModuleHasDebugInfo(M))
      return;

    std::string DiName = (Twine("__") + Name).str();
    std::replace(DiName.begin(), DiName.end(), '.', '_');

    vc::DIBuilder DBuilder(M);
    auto *DGVType = DBuilder.translateTypeToDIType(*Var.getValueType());
    if (!DGVType) {
      LLVM_DEBUG(dbgs() << "ERROR: could not create debug info for implict var:"
                        << Var << "\n");
      return;
    }
    auto *GVE =
        DBuilder.createGlobalVariableExpression(DiName, DiName, DGVType);
    Var.addDebugInfo(GVE);
  }

  // Returns the type of the value of the implicit arg denoted by \p IID, or
  // nullptr if it cannot be determined.
  static Type *getIntrinRetType(LLVMContext &Context, unsigned IID) {
    switch (IID) {
    case vc::InternalIntrinsic::assert_buffer:
    case vc::InternalIntrinsic::print_buffer:
    case vc::InternalIntrinsic::sync_buffer:
    case PseudoIntrinsic::PrivateBase:
    case PseudoIntrinsic::ImplicitArgsBuffer:
      return llvm::Type::getInt64Ty(Context);
    case GenXIntrinsic::genx_local_id:
    case GenXIntrinsic::genx_local_size:
    case GenXIntrinsic::genx_group_count:
      return IGCLLVM::FixedVectorType::get(llvm::Type::getInt32Ty(Context), 3);
    case GenXIntrinsic::genx_local_id16:
      return IGCLLVM::FixedVectorType::get(llvm::Type::getInt16Ty(Context), 3);
    default:
      // Should be able to extract the type from the intrinsic
      // directly as no overloading is required (if it is then
      // you need to define specific type in a case statement above)
      FunctionType *FTy = dyn_cast_or_null<FunctionType>(
          GenXIntrinsic::getAnyType(Context, IID));
      if (FTy)
        return FTy->getReturnType();
    }
    return nullptr;
  }

  // GlobalVariables that have been created for an intrinsic
  SmallDenseMap<unsigned, GlobalVariable *> GlobalsMap;
};

// A helper class to recursively traverse the call graph and collect all the
// required implicit args.
// Only temporary objects should be constructed. Usage:
// CallGraphTraverser{CG, UsedIntr}.collectIndirectlyUsedImplArgs(F);
class CallGraphTraverser {
  const CallGraph &CG;
  const IntrIDMap &UsedIntr;
  std::unordered_set<Function *> Visited;
  IntrIDSet CollectedIID;
  MaybeUndefFuncSeq CalledFixedSignFuncs{MaybeUndefFuncSeq::value_type{}};

public:
  CallGraphTraverser(const CallGraph &CGIn, const IntrIDMap &UsedIntrIn)
      : CG{CGIn}, UsedIntr{UsedIntrIn} {}

  // Returns a pair of indirectly used implicit args and called fixed signature
  // functions. Indirectly used implicit args are those implicit args that are
  // used in the provided function \p F and in the non-fixed signature
  // functions it recursively calls (called from \p F, called from those called
  // from \p F, and so on). Fixed signature functions and indirect calls stop
  // the traversal; their indirectly used implicit args aren't collected. The
  // 2nd return value describes whether some fixed signature functions were
  // called and, if so, which ones, when that set can be defined (the case
  // when some were called but the exact set cannot be defined is also
  // represented, see MaybeUndefFuncSeq).
  std::pair<IntrIDSet, MaybeUndefFuncSeq>
  collectIndirectlyUsedImplArgs(Function &F) && {
    IGC_ASSERT_MESSAGE(
        vc::isFixedSignatureDefinition(F) || vc::isKernel(&F),
        "entry point must be a fixed signature function or a kernel");
    visitFunction</*IsEntry =*/true>(F);
    if (CalledFixedSignFuncs.has_value()) {
      IGC_ASSERT_MESSAGE(isUnique(CalledFixedSignFuncs.value()),
                         "values in CalledFixedSignFuncs must be unique");
    }
    // The method is &&-qualified, so it runs on a temporary that is about to
    // be destroyed: move the collected data out instead of copying it.
    return {std::move(CollectedIID), std::move(CalledFixedSignFuncs)};
  }

private:
  template <bool IsEntry = false> void visitFunction(Function &F);
};

} // namespace

// Forward declarations of module-level helpers used by the pass.
static ImplArgIntrSeq collectImplicitArgIntrinsics(Module &M);
static IntrIDMap fillUsedIntrMap(const ImplArgIntrSeq &Workload);

// Returns true if \p IID is a pseudo intrinsic ID, i.e. it lies in the
// half-open range [PseudoIntrinsic::First, PseudoIntrinsic::Last).
static bool isPseudoIntrinsic(unsigned IID) {
  const bool NotBelowFirst = PseudoIntrinsic::First <= IID;
  const bool BelowLast = IID < PseudoIntrinsic::Last;
  return NotBelowFirst && BelowLast;
}

// Decides whether a kernel must have access to the implicit args buffer.
// This is the case when:
//   * the set of fixed signature functions the kernel calls cannot be defined
//     (e.g. because of an indirect call or an externally defined function),
//     so the buffer is conservatively presumed to be required; or
//   * one of the called fixed signature functions is in
//     \p RequireImplArgsBuffer.
bool kernelRequiresImplArgBuffer(
    const MaybeUndefFuncSeq &CalledFixedSignFuncs,
    const std::unordered_set<Function *> &RequireImplArgsBuffer) {
  if (!CalledFixedSignFuncs)
    return true;
  const auto &Callees = *CalledFixedSignFuncs;
  return std::any_of(Callees.begin(), Callees.end(),
                     [&RequireImplArgsBuffer](Function *Callee) {
                       return RequireImplArgsBuffer.count(Callee) != 0;
                     });
}

// Creates the predefined vISA variables needed to work with implicit args in
// extern and indirect functions on architectures with payload on registers.
// Returns {implicit args buffer variable, local ID buffer variable}.
// The variables must not exist yet when this function is called; if you are
// not sure whether they have already been created, use getOrCreatePredefVars.
static std::pair<GlobalVariable *, GlobalVariable *>
createPredefVars(Module &M) {
  // Created in this order: implicit args buffer first, then local ID buffer.
  auto &ImplArgsBufferVar = vc::PredefVar::createImplicitArgsBuffer(M);
  auto &LocalIDBufferVar = vc::PredefVar::createLocalIDBuffer(M);
  return {&ImplArgsBufferVar, &LocalIDBufferVar};
}

// Returns {implicit args buffer variable, local ID buffer variable}. Reuses
// the predefined variables already present in \p M, otherwise creates them.
static std::pair<GlobalVariable *, GlobalVariable *>
getOrCreatePredefVars(Module &M) {
  GlobalVariable *ExistingImplArgs =
      M.getNamedGlobal(vc::PredefVar::ImplicitArgsBufferName);
  if (!ExistingImplArgs)
    return createPredefVars(M);

  // Both variables are always created together, so the local ID one must
  // exist as well.
  GlobalVariable *ExistingLocalID =
      M.getNamedGlobal(vc::PredefVar::LocalIDBufferName);
  IGC_ASSERT_MESSAGE(
      ExistingLocalID,
      "If implict args buffer predefined variable is created, local ID "
      "buffer predefined variable must be created too");
  return {ExistingImplArgs, ExistingLocalID};
}

// Inserts (via \p IRB) code that initializes the local ID buffer predefined
// variable \p LocalIDBufferVar: a local ID buffer is allocated on the
// \p Kernel stack, its X/Y/Z fields are filled from the local ID implicit
// argument (which must already have been added to the kernel arguments), and
// the buffer address, converted to i64, is written into \p LocalIDBufferVar.
static void initializeLocalIDBufferVariable(Function &Kernel,
                                            GlobalVariable &LocalIDBufferVar,
                                            IRBuilder<> &IRB) {
  using namespace vc::ImplicitArgs;
  IGC_ASSERT_MESSAGE(vc::isKernel(Kernel),
                     "wrong argument: a kernel must be provided");

  Argument &LocalIDs =
      vc::getImplicitArg(Kernel, vc::KernelMetadata::IMP_LOCAL_ID);
  Module &Mod = *LocalIDBufferVar.getParent();
  auto *Buffer = IRB.CreateAlloca(&LocalID::getType(Mod),
                                  vc::AddrSpace::Private, nullptr,
                                  "loc.id.buffer");

  // Copy every component of the local ID vector into the matching field.
  for (LocalID::Indices::Enum Dim :
       {LocalID::Indices::X, LocalID::Indices::Y, LocalID::Indices::Z}) {
    Value *Component =
        IRB.CreateExtractElement(&LocalIDs, Dim, "loc.id." + Twine{Dim});
    Value *FieldPtr = IRB.CreateGEP(
        Buffer->getAllocatedType(), Buffer, {IRB.getInt32(0), IRB.getInt32(Dim)},
        "loc.id.ptr." + Twine{Dim});
    IRB.CreateStore(Component, FieldPtr);
  }

  Value *BufferAddr =
      IRB.CreatePtrToInt(Buffer, IRB.getInt64Ty(), "loc.id.buf.int.ptr");
  vc::createWriteVariableRegion(LocalIDBufferVar, *BufferAddr, IRB);
}

// Inserts the prologue required on architectures with payload on registers at
// the first insertion point of the \p Kernel entry block. The prologue copies
// the implicit args buffer implicit argument into its predefined variable and
// initializes the local ID buffer predefined variable (creating the predefined
// variables if they don't exist yet).
static void addKernelPrologue(Function &Kernel) {
  using namespace vc::ImplicitArgs;
  IGC_ASSERT_MESSAGE(vc::isKernel(Kernel),
                     "wrong argument: a kernel must be provided");

  BasicBlock &EntryBB = Kernel.getEntryBlock();
  IRBuilder<> Builder{&*EntryBB.getFirstInsertionPt()};
  std::pair<GlobalVariable *, GlobalVariable *> PredefVars =
      getOrCreatePredefVars(*Kernel.getParent());

  Argument &BufferArg =
      vc::getImplicitArg(Kernel, vc::KernelMetadata::IMP_IMPL_ARGS_BUFFER);
  vc::createWriteVariableRegion(*PredefVars.first, BufferArg, Builder);

  initializeLocalIDBufferVariable(Kernel, *PredefVars.second, Builder);
}

bool CMImpParam::runOnModule(Module &M) {
  DL = &M.getDataLayout();

  // Adjust the code for the OpenCL runtime first; this may introduce
  // implicit-arg intrinsics (e.g. local_id16) that the analysis below sees.
  ConvertToOCLPayload(M);

  // Collect all the data before any transformation takes place.
  ImplArgIntrSeq Workload = collectImplicitArgIntrinsics(M);
  std::vector<FunctionRef> Kernels{vc::kernel_begin(M), vc::kernel_end(M)};
  IntrIDMap UsedIntrInfo = fillUsedIntrMap(Workload);
  auto [FixedSignFuncInfo, RequireImplArgsBuffer] =
      analyzeFixedSignatureFunctions(M, UsedIntrInfo);

  // Had ConvertToOCLPayload changed anything, Workload would contain at least
  // the local_id16 intrinsics, so reporting "no change" here is correct.
  const bool NothingToProcess =
      Workload.empty() && Kernels.empty() && FixedSignFuncInfo.empty();
  if (NothingToProcess)
    return false;

  replaceImplicitArgIntrinsics(Workload);

  // On architectures with payload on registers, fixed signature functions
  // access implicit args through predefined variables, so those variables
  // must exist whenever such functions are present.
  if (!(FixedSignFuncInfo.empty() || HasPayloadInMemory))
    createPredefVars(M);

  for (auto &FuncAndArgs : FixedSignFuncInfo)
    processFixedSignatureFunction(*FuncAndArgs.first, FuncAndArgs.second);

  // Kernels are rewritten last: changing them invalidates the data collected
  // above.
  processKernels(Kernels, UsedIntrInfo, RequireImplArgsBuffer);

  return true;
}

// Adds the required implicit args to every kernel in \p Kernels.
// For each kernel the call graph is traversed to find the implicit args used
// by the kernel and its subroutines (\p UsedIntrInfo holds the direct uses).
// Kernels that may reach a function from \p RequireImplArgsBuffer get the
// implicit args buffer attribute and, with payload on registers, a prologue
// that publishes the buffer and local IDs via predefined variables.
void CMImpParam::processKernels(
    const std::vector<FunctionRef> &Kernels, const IntrIDMap &UsedIntrInfo,
    const std::unordered_set<Function *> &RequireImplArgsBuffer) {
#if LLVM_VERSION_MAJOR < 16
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
#endif

  for (Function &Kernel : Kernels) {
    // Traverse the call graph to determine what the total implicit uses are for
    // the top level kernels.
    auto [RequiredImplArgs, CalledFixedSignFuncs] =
        CallGraphTraverser{CG, UsedIntrInfo}.collectIndirectlyUsedImplArgs(
            Kernel);
    bool KernelRequiresImplArgBuffer = kernelRequiresImplArgBuffer(
        CalledFixedSignFuncs, RequireImplArgsBuffer);
    // For kernels that require implicit args buffer and architectures that
    // have payload on registers a prologue must be inserted. In this prologue
    // implicit args buffer implicit arg is copied into the corresponding
    // predefined variable.
    bool KernelRequiresPrologueInsertion =
        KernelRequiresImplArgBuffer && !HasPayloadInMemory;

    if (KernelRequiresImplArgBuffer)
      Kernel.addFnAttr(vc::ImplicitArgs::KernelAttr);
    if (KernelRequiresPrologueInsertion) {
      RequiredImplArgs.emplace(PseudoIntrinsic::ImplicitArgsBuffer);
      RequiredImplArgs.emplace(GenXIntrinsic::genx_local_id16);
    }
    // For OCL/L0 RT we should unconditionally add implicit PRIVATE_BASE
    // argument which is not supported on CM RT.
    RequiredImplArgs.emplace(PseudoIntrinsic::PrivateBase);
    vc::internal::createInternalMD(Kernel);
    // RequiredImplArgs cannot be empty here (PRIVATE_BASE was just added), so
    // the kernel parameters are always rewritten.
    CallGraphNode *NewKernelNode =
        processKernelParameters(&Kernel, RequiredImplArgs);
    // processKernelParameters yields a new kernel function; the prologue
    // goes into that one.
    if (KernelRequiresPrologueInsertion)
      addKernelPrologue(*NewKernelNode->getFunction());
  }
}

// Returns:
//    0: Map from a fixed signature function to the set of implicit args used in
//       the function and its subroutines. Implicit args are defined by IIDs.
//    1: Set of functions which require the kernel that calls them to have the
//       implicit args buffer. This set is wider than the set of functions in
//       the 0 return value since implicit args may be used not only in the
//       function and its subroutines but also in some called fixed signature
//       function, which doesn't require loading implicit args in the
//       considered function but does require the implicit args buffer to be
//       present so that it can load implicit args on its side.
// Note: "subroutines" above means functions whose signatures will be changed
// and whose implicit args will be passed as additional parameters. This may
// also be an internal stack call.
std::pair<IntrIDMap, std::unordered_set<Function *>>
CMImpParam::analyzeFixedSignatureFunctions(Module &M,
                                           const IntrIDMap &UsedIntrInfo) {
  IntrIDMap FixedSignFuncInfo;
  std::unordered_set<Function *> RequireImplArgsBuffer;
#if LLVM_VERSION_MAJOR < 16
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
#endif
  auto FixedSignatureDefinitions = make_filter_range(
      M, [](const Function &F) { return vc::isFixedSignatureDefinition(F); });

  for (Function &F : FixedSignatureDefinitions) {
    auto [RequiredImplArgs, CalledFixedSignFuncs] =
        CallGraphTraverser{CG, UsedIntrInfo}.collectIndirectlyUsedImplArgs(F);

    // Function should be marked if it uses implicit args or it calls some
    // function that may use it via implicit args buffer.
    if (!RequiredImplArgs.empty() || !CalledFixedSignFuncs.has_value() ||
        !CalledFixedSignFuncs.value().empty())
      RequireImplArgsBuffer.emplace(&F);

    if (!RequiredImplArgs.empty())
      FixedSignFuncInfo.emplace(&F, std::move(RequiredImplArgs));
  }
  // A braced return does not implicitly move named locals; move both
  // containers explicitly instead of copying them.
  return {std::move(FixedSignFuncInfo), std::move(RequireImplArgsBuffer)};
}

// For the fixed signature function \p F, inserts code at the start of its
// entry block that accesses every implicit arg listed in \p UsedIntrInfo
// through the implicit args buffer and stores it into the corresponding
// implicit arg global variable. Implicit args are processed in ascending IID
// order (the iteration order of the set).
void CMImpParam::processFixedSignatureFunction(Function &F,
                                               const IntrIDSet &UsedIntrInfo) {
  BasicBlock &EntryBB = F.getEntryBlock();
  IRBuilder<> Builder{&*EntryBB.getFirstInsertionPt()};
  auto &BufferPtr =
      vc::ImplicitArgs::Buffer::getPointer(Builder, getThreadPayloadKind());
  std::for_each(UsedIntrInfo.begin(), UsedIntrInfo.end(),
                [this, &BufferPtr, &Builder](unsigned IID) {
                  storeImplArgInFixedSignatureFunction(IID, BufferPtr, Builder);
                });
}

// Returns the indices of the three implicit args buffer structure fields
// (X, Y and Z components, in this order) that have to be loaded to obtain the
// value of the implicit arg denoted by intrinsic ID \p IID.
// Specializations are provided for genx_local_size and genx_group_count.
template <unsigned IID>
std::array<vc::ImplicitArgs::Buffer::Indices::Enum, 3>
getIABufferIndicesForIID();

template <>
std::array<vc::ImplicitArgs::Buffer::Indices::Enum, 3>
getIABufferIndicesForIID<GenXIntrinsic::genx_local_size>() {
  std::array<vc::ImplicitArgs::Buffer::Indices::Enum, 3> LocalSizeFields = {
      vc::ImplicitArgs::Buffer::Indices::LocalSizeX,
      vc::ImplicitArgs::Buffer::Indices::LocalSizeY,
      vc::ImplicitArgs::Buffer::Indices::LocalSizeZ};
  return LocalSizeFields;
}

template <>
std::array<vc::ImplicitArgs::Buffer::Indices::Enum, 3>
getIABufferIndicesForIID<GenXIntrinsic::genx_group_count>() {
  std::array<vc::ImplicitArgs::Buffer::Indices::Enum, 3> GroupCountFields = {
      vc::ImplicitArgs::Buffer::Indices::GroupCountX,
      vc::ImplicitArgs::Buffer::Indices::GroupCountY,
      vc::ImplicitArgs::Buffer::Indices::GroupCountZ};
  return GroupCountFields;
}

// Inserts code that loads a 3-component implicit argument from the implicit
// args buffer and combines the components into a vector value.
// The implicit argument is defined via the corresponding intrinsic ID \p IID
// (getIABufferIndicesForIID<IID> must be specialized for it).
// Arguments:
//    \p ImplArgsBufferPtr - a pointer to implicit args buffer structure;
//    \p IRB - IR builder used to insert the code.
template <unsigned IID>
Value &loadVec3ArgFromIABuffer(Value &ImplArgsBufferPtr, IRBuilder<> &IRB) {
  auto Indices = getIABufferIndicesForIID<IID>();
  std::array<Value *, 3> VectorElements;
  // Use an explicit loop: each load inserts IR, so the components must be
  // emitted in order, and std::transform does not guarantee in-order
  // application of its callable.
  for (std::size_t I = 0; I < Indices.size(); ++I)
    VectorElements[I] = &vc::ImplicitArgs::Buffer::loadField(
        ImplArgsBufferPtr, Indices[I], IRB, "impl.arg.vec.elem");
  return *vc::accumulateVector(VectorElements, VectorElements.size(), IRB,
                               "impl.arg.vec");
}

// Inserts code that loads the local IDs (X, Y, Z) from the local ID structure
// reachable through the implicit args buffer and combines them into a
// 3-component vector.
// Arguments:
//    \p ImplArgsBufferPtr - a pointer to implicit args buffer structure;
//    \p IRB - IR builder used to insert the code;
//    \p PayloadKind - where the thread payload lives, passed on to
//                     LocalID::getPointer.
static Value &loadLocalIDFromIABuffer(Value &ImplArgsBufferPtr,
                                      IRBuilder<> &IRB,
                                      vc::ThreadPayloadKind PayloadKind) {
  using namespace vc::ImplicitArgs;
  std::array<LocalID::Indices::Enum, 3> Indices = {
      LocalID::Indices::X, LocalID::Indices::Y, LocalID::Indices::Z};
  auto &LIDStructPtr = LocalID::getPointer(ImplArgsBufferPtr, IRB, PayloadKind);
  std::array<Value *, 3> VectorElements;
  // Explicit loop: the loads insert IR and must be emitted in X, Y, Z order,
  // which std::transform does not guarantee.
  for (std::size_t I = 0; I < Indices.size(); ++I)
    VectorElements[I] = &LocalID::loadField(LIDStructPtr, Indices[I], IRB,
                                            "ia.local.id.elem");
  return *vc::accumulateVector(VectorElements, VectorElements.size(), IRB,
                               "impl.arg.vec");
}

// Inserts code that obtains the value of an implicit argument from the
// implicit args buffer and returns it.
// Arguments:
//    \p IID - the ID of the intrinsic that corresponds to the required
//             implicit arg;
//    \p ImplArgsBufferPtr - a pointer to implicit args buffer structure;
//    \p IRB - IR builder used to insert the code;
//    \p PayloadKind - thread payload location, used for local IDs.
// IDs that cannot be served from the buffer trigger an assertion; the returned
// \p ImplArgsBufferPtr in those paths is only a placeholder.
static Value &loadArgFromIABuffer(unsigned IID, Value &ImplArgsBufferPtr,
                                  IRBuilder<> &IRB,
                                  vc::ThreadPayloadKind PayloadKind) {
  switch (IID) {
  // Implicit args that are loaded from the buffer.
  case GenXIntrinsic::genx_local_id16:
    return loadLocalIDFromIABuffer(ImplArgsBufferPtr, IRB, PayloadKind);
  case GenXIntrinsic::genx_local_size:
    return loadVec3ArgFromIABuffer<GenXIntrinsic::genx_local_size>(
        ImplArgsBufferPtr, IRB);
  case GenXIntrinsic::genx_group_count:
    return loadVec3ArgFromIABuffer<GenXIntrinsic::genx_group_count>(
        ImplArgsBufferPtr, IRB);
  case vc::InternalIntrinsic::print_buffer:
    return vc::ImplicitArgs::Buffer::loadField(
        ImplArgsBufferPtr, vc::ImplicitArgs::Buffer::Indices::PrintfBufferPtr,
        IRB, "printf.buffer.ptr");

  // Implicit args that must never be requested from the buffer.
  case GenXIntrinsic::genx_local_id:
    IGC_ASSERT_EXIT_MESSAGE(
        0, "IA buffer is supported only for OCL/L0 so local.id "
           "must have been already transformed into local.id16");
    return ImplArgsBufferPtr;
  case GenXIntrinsic::genx_get_scoreboard_deltas:
  case GenXIntrinsic::genx_get_scoreboard_bti:
  case GenXIntrinsic::genx_get_scoreboard_depcnt:
    // An assertion rather than a diagnostic: the diagnostic must have been
    // emitted earlier.
    IGC_ASSERT_EXIT_MESSAGE(
        0, "IA buffer is supported only for OCL/L0, scoreboard "
           "builtins should not appear for those runtimes");
    return ImplArgsBufferPtr;
  default:
    IGC_ASSERT_EXIT_MESSAGE(0, "unexpected intrinsic id");
    return ImplArgsBufferPtr;
  }
}

// Accesses the required implicit arg and stores it into the corresponding
// implicit arg global variable.
// Arguments:
//    \p IID - the ID of intrinsic that corresponds to the required implicit
//             arg;
//    \p ImplArgsBufferPtr - a pointer to implicit args buffer structure;
//    \p IRB - IR builder used to insert the code.
void CMImpParam::storeImplArgInFixedSignatureFunction(unsigned IID,
                                                      Value &ImplArgsBufferPtr,
                                                      IRBuilder<> &IRB) {
  auto &ImplicitArg =
      loadArgFromIABuffer(IID, ImplArgsBufferPtr, IRB, getThreadPayloadKind());
  IGC_ASSERT_MESSAGE(GlobalsMap.count(IID),
                     "must have a corresponding global since the arg use was "
                     "already replaced with a load from the global");
  IRB.CreateStore(&ImplicitArg, GlobalsMap[IID]);
}

// Replaces the implicit-arg intrinsic call \p CI with a load of the global
// variable associated with its intrinsic ID (obtained through
// getOrCreateGlobalForIID for the calling function). The load inherits the
// call's name and debug location, all uses of the call are redirected to it,
// and the call itself is erased.
void CMImpParam::replaceWithGlobal(CallInst *CI) {
  IGC_ASSERT_MESSAGE(GenXIntrinsic::isGenXIntrinsic(CI) ||
                         vc::InternalIntrinsic::isInternalIntrinsic(CI),
                     "genx or vc internal intrinsic is expected");
  auto IntrID = vc::getAnyIntrinsicID(CI->getCalledFunction());
  Function *Caller = CI->getFunction();
  GlobalVariable *ImplArgGV = getOrCreateGlobalForIID(Caller, IntrID);

  auto *Replacement =
      new LoadInst(ImplArgGV->getValueType(), ImplArgGV, "",
                   /* isVolatile */ false,
                   IGCLLVM::getCorrectAlign(ImplArgGV->getAlignment()), CI);
  Replacement->setDebugLoc(CI->getDebugLoc());
  Replacement->takeName(CI);

  CI->replaceAllUsesWith(Replacement);
  CI->eraseFromParent();
}

// Returns true if \p Arg is a pointer argument carrying the byval attribute,
// i.e. an aggregate passed by value that has to be linearized. The byval type
// is required to be a non-opaque struct (enforced by cast<> and an assertion).
static bool isSupportedAggregateArgument(Argument &Arg) {
  const bool IsByValPointer =
      Arg.getType()->isPointerTy() && Arg.hasByValAttr();
  if (!IsByValPointer)
    return false;

  auto *ByValStruct =
      cast<StructType>(Arg.getParent()->getParamByValType(Arg.getArgNo()));
  IGC_ASSERT(!ByValStruct->isOpaque());
  return true;
}

// A helper structure to store current state of the aggregate traversal.
// Each instance is one entry of the explicit stack used by
// LinearizeAggregateType.
struct PendingTypeInfo {
  Type *Ty;         // Type to decompose
  unsigned NextElt; // Subelement number to decompose next (struct/array only)
  unsigned Offset;  // Byte offset, from the start of the outermost aggregate,
                    // of the first primitive type reached from element
                    // NextElt of Ty (or of Ty itself if it is primitive)
};

// Byval aggregate arguments must be linearized. This function decomposes the
// aggregate type \p AggrTy into its primitive (non-struct, non-array) types
// and returns them with their byte offsets inside \p AggrTy, in memory order.
// Struct field offsets come from the struct layout. Array elements are spaced
// by the element's alloc size, which is the stride GEP uses. Empty structs
// and zero-length arrays contribute nothing.
//
// The decomposition is recursive by nature. It is rewritten to use an explicit
// stack so that deeply nested types do not cause recursion problems.
// Example:
//   struct s1 {
//     struct s2 {
//       int a;
//     } c;
//     char b;
//   };
//
//                Pending(stack) | LinTy(output)
// Start:
//                s1, 0, 0       | -
// Iteration 0:
//                s1, 1, 4       | -
//                s2, 0, 0       |
//   Comment: two elements in stack. "s1, 1, 4" means subtype number 1 of s1
//   must be decomposed next, and the first primitive type in that subtype has
//   offset 4. Note that this subtype may itself be an aggregate type; in that
//   case offset 4 is propagated to its first nested primitive type.
// Iteration 1:
//                s1, 1, 4       | -
//                int, 0, 0      | -
// Iteration 2:
//                s1, 1, 4       | int, 0
// Iteration 3:
//                char, 0, 4     | int, 0
// Iteration 4:
//                -              | int, 0
//                               | char, 4
//
LinearizedTy CMImpParam::LinearizeAggregateType(Type *AggrTy) {
  LinearizedTy LinTy;

  std::stack<PendingTypeInfo> Pending;
  Pending.push({AggrTy, 0, 0});

  while (!Pending.empty()) {
    PendingTypeInfo Info = Pending.top();
    Pending.pop();
    Type *CurTy = Info.Ty;
    unsigned CurElt = Info.NextElt;
    unsigned NextElt = CurElt + 1;
    if (auto *STy = dyn_cast<StructType>(CurTy)) {
      unsigned NumElts = STy->getStructNumElements();
      // An empty struct has no fields, so there is nothing to pass.
      if (NumElts == 0)
        continue;
      const StructLayout *Layout = DL->getStructLayout(STy);

      IGC_ASSERT(CurElt < NumElts);
      Type *EltType = STy->getElementType(CurElt);
      if (NumElts > NextElt) {
        // Use layout offsets so that padding between fields is respected.
        unsigned CurOffset = Layout->getElementOffset(CurElt);
        unsigned EltOffset = Layout->getElementOffset(NextElt) - CurOffset;
        Pending.push({CurTy, NextElt, Info.Offset + EltOffset});
      }
      Pending.push({EltType, 0, Info.Offset});

    } else if (auto *ATy = dyn_cast<ArrayType>(CurTy)) {
      uint64_t NumElts = ATy->getNumElements();
      // A zero-length array occupies no storage.
      if (NumElts == 0)
        continue;
      Type *EltTy = ATy->getElementType();
      // Consecutive array elements are getTypeAllocSize apart, which includes
      // the tail padding up to the element's ABI alignment. The store size
      // omits that padding (e.g. for 3-element vectors), and using it would
      // misplace every element after the first.
      unsigned EltStride = DL->getTypeAllocSize(EltTy);

      if (NumElts > NextElt)
        Pending.push({CurTy, NextElt, Info.Offset + EltStride});
      Pending.push({EltTy, 0, Info.Offset});
    } else
      LinTy.push_back({CurTy, Info.Offset});
  }

  return LinTy;
}

// Builds the linearization info for \p F. Every argument accepted by
// isSupportedAggregateArgument (a byval struct) is mapped to the primitive
// types and byte offsets produced by LinearizeAggregateType. Other arguments
// get no entry.
ArgLinearization CMImpParam::GenerateArgsLinearizationInfo(Function &F) {
  ArgLinearization Result;
  for (Argument &A : F.args()) {
    if (isSupportedAggregateArgument(A)) {
      IGC_ASSERT(isa<PointerType>(A.getType()));
      Type *ByValTy = F.getParamByValType(A.getArgNo());
      Result[&A] = LinearizeAggregateType(cast<StructType>(ByValTy));
    }
  }
  return Result;
}

// Returns true if \p F is an intrinsic whose calls must be served through an
// implicit kernel argument. Scoreboard intrinsics are reported as unsupported
// via vc::diagnose and are not treated as implicit arg intrinsics.
static bool isImplicitArgIntrinsic(const Function &F) {
  switch (vc::getAnyIntrinsicID(&F)) {
  case GenXIntrinsic::genx_get_scoreboard_deltas:
  case GenXIntrinsic::genx_get_scoreboard_bti:
  case GenXIntrinsic::genx_get_scoreboard_depcnt:
    vc::diagnose(F.getContext(), "GenXImplicitParameters",
                 "scoreboarding intrinsics are not supported", &F);
    break;
  case vc::InternalIntrinsic::assert_buffer:
  case vc::InternalIntrinsic::print_buffer:
  case vc::InternalIntrinsic::sync_buffer:
  case GenXIntrinsic::genx_group_count:
  case GenXIntrinsic::genx_local_id:
  case GenXIntrinsic::genx_local_id16:
  case GenXIntrinsic::genx_local_size:
    return true;
  default:
    break;
  }
  return false;
}

// Gathers every call to an implicit-arg intrinsic (see isImplicitArgIntrinsic),
// such as llvm.genx.local.size, across module \p M. Calls are appended in
// module function order, then in each intrinsic's user-list order. Every user
// of such an intrinsic is required to be a CallInst.
static ImplArgIntrSeq collectImplicitArgIntrinsics(Module &M) {
  ImplArgIntrSeq Calls;
  for (Function &Decl : M) {
    if (!isImplicitArgIntrinsic(Decl))
      continue;
    for (User *U : Decl.users())
      Calls.push_back(cast<CallInst>(U));
  }
  return Calls;
}

// Maps each function to the set of implicit-arg intrinsic IDs it calls,
// based on the calls collected in \p Workload.
static IntrIDMap fillUsedIntrMap(const ImplArgIntrSeq &Workload) {
  IntrIDMap UsedIntrInfo;
  for (CallInst *Call : Workload) {
    Function *Caller = Call->getFunction();
    Function *Callee = Call->getCalledFunction();
    UsedIntrInfo[Caller].insert(vc::getAnyIntrinsicID(Callee));
  }
  return UsedIntrInfo;
}

// Replace implicit arg intrinsics collected in \p Workload with a load of
// the corresponding __imparg global variable.
// Fill implicit args usage data.
void CMImpParam::replaceImplicitArgIntrinsics(const ImplArgIntrSeq &Workload) {
  for (CallInst *Intr : Workload)
    replaceWithGlobal(Intr);
}

// Converts thread payload intrinsics to their OCL form. Every instruction user
// of llvm.genx.local.id (<3 x i32>) is rewritten as a call to
// llvm.genx.local.id16 (<3 x i16>) followed by a zext back to <3 x i32>.
// Non-instruction users are left untouched.
void CMImpParam::ConvertToOCLPayload(Module &M) {
  LLVMContext &Ctx = M.getContext();
  Type *Ty32 = IGCLLVM::FixedVectorType::get(Type::getInt32Ty(Ctx), 3);
  Type *Ty16 = IGCLLVM::FixedVectorType::get(Type::getInt16Ty(Ctx), 3);

  Function *LIDFn = M.getFunction(
      GenXIntrinsic::getAnyName(GenXIntrinsic::genx_local_id, Ty32));
  if (!LIDFn)
    return;

  Function *LID16 = GenXIntrinsic::getGenXDeclaration(
      &M, GenXIntrinsic::genx_local_id16, Ty16);
  // Users are erased inside the loop, so advance before visiting each one.
  for (User *U : llvm::make_early_inc_range(LIDFn->users())) {
    auto *OldCall = dyn_cast<Instruction>(U);
    if (!OldCall)
      continue;
    IRBuilder<> Builder(OldCall);
    Value *Narrow = Builder.CreateCall(LID16, llvm::ArrayRef<Value *>(),
                                       OldCall->getName() + ".i16");
    Value *Widened = Builder.CreateZExt(Narrow, Ty32);
    Widened->takeName(OldCall);
    OldCall->replaceAllUsesWith(Widened);
    OldCall->eraseFromParent();
  }
}

// Recursively visits \p F and its children in call graph that are not fixed
// signature functions. Collects the required info through the traversal.
// \p IsEntry indicates that the provided \p F is the start of traversal and
// the method is not called from itself (we're not inside recursion yet).
//
// State updated during the traversal:
//  * Visited - every reached function, so each one is processed only once;
//  * CollectedIID - implicit arg intrinsic IDs (taken from UsedIntr) used by
//    \p F and by the non-fixed-signature functions reachable from it;
//  * CalledFixedSignFuncs - the reachable fixed signature functions, at which
//    the traversal stops. The optional is reset once a call is met whose
//    callee is unknown (indirect call) or is a declaration, meaning the set
//    of called fixed signature functions cannot be determined.
template <bool IsEntry> void CallGraphTraverser::visitFunction(Function &F) {
  // If this node has already been processed then return immediately
  if (Visited.count(&F))
    return;

  // Add this node to the already visited list
  Visited.insert(&F);

  // Have to stop on functions which signatures cannot be changed (won't be
  // able to pass an implicit argument as an additional argument there).
  // Entry is an external function by definition, don't stop on entry.
  if constexpr (!IsEntry) {
    IGC_ASSERT_MESSAGE(!vc::isKernel(&F), "kernel call is unexpected");
    if (vc::isFixedSignatureFunc(F)) {
      IGC_ASSERT_MESSAGE(!F.isDeclaration(),
                         "declarations are unexpected: call graph edge cannot "
                         "lead to a declaration");
      if (CalledFixedSignFuncs.has_value())
        // Add only while no call with an undetermined callee has been met
        // (such a call resets CalledFixedSignFuncs below).
        CalledFixedSignFuncs.value().push_back(&F);
      return;
    }
  }

  // Handle current node: add its used implicit intrinsic IDs if present.
  if (UsedIntr.count(&F))
    CollectedIID.insert(UsedIntr.at(&F).begin(), UsedIntr.at(&F).end());

  // Start the traversal
  const CallGraphNode *N = CG[&F];
  // Inspect all children (recursive)
  for (IGCLLVM::CallRecord CallEdge : *N) {
    // Skipping reference edges.
    if (!CallEdge.first)
      continue;
    Value *CI = IGCLLVM::makeOptional(CallEdge.first).value();
    // Skipping inline asm.
    if (isa<CallInst>(CI) && cast<CallInst>(CI)->isInlineAsm())
      continue;
    // Skip call sites that are not CallInsts (e.g. invokes) and calls to
    // non-trivial intrinsics.
    if (!isa<CallInst>(CI) || vc::isAnyNonTrivialIntrinsic(CI))
      continue;
    // Returns nullptr in case of indirect call or inline asm which was already
    // considered.
    auto *Child = CallEdge.second->getFunction();
    if (!Child)
      IGC_ASSERT_MESSAGE(
          isa<CallInst>(CI) && cast<CallInst>(CI)->isIndirectCall(),
          "only indirect call is exprected for a null call graph node");
    if (Child && !Child->isDeclaration())
      visitFunction(*Child);
    else
      // Cannot define the set of called functions.
      CalledFixedSignFuncs.reset();
  }
}

// Returns the name given to the kernel argument that carries implicit arg
// \p IID. Real intrinsics produce "impl.arg." followed by the intrinsic name.
// The pseudo intrinsics (implicit args buffer, private base) use fixed names.
static std::string getImplicitArgName(unsigned IID) {
  if (isPseudoIntrinsic(IID)) {
    if (IID == PseudoIntrinsic::ImplicitArgsBuffer)
      return "impl.arg.impl.args.buffer";
    IGC_ASSERT_MESSAGE(IID == PseudoIntrinsic::PrivateBase,
                       "there's only private base pseudo intrinsic for now");
    return "impl.arg.private.base";
  }
  return "impl.arg." + vc::getAnyName(IID, llvm::ArrayRef<Type *>());
}

// Process a kernel - loads from a global (and the globals) have already been
// added if required elsewhere (in doInitialization)
// We've already determined that this is a kernel and that it requires some
// implicit arguments adding
//
// Rebuilds kernel \p F as a new function NF whose argument list is, in order:
//  1. the explicit arguments of \p F (with their parameter attributes);
//  2. one argument per ID in \p UsedImplicits, typed by getIntrinRetType;
//  3. one argument per primitive element of each linearized byval aggregate
//     argument (see GenerateArgsLinearizationInfo).
// At the start of NF, each argument from (2) is stored into its global from
// GlobalsMap when such a global exists. The kernel metadata of NF gets the new
// argument kinds and the linearization info. The old function is deleted if
// its call graph node has no references left; otherwise it is kept with
// external linkage.
// Returns the call graph node of NF.
CallGraphNode *
CMImpParam::processKernelParameters(Function *F,
                                    const IntrIDSet &UsedImplicits) {
  LLVMContext &Context = F->getContext();

  IGC_ASSERT_MESSAGE(
      vc::isKernel(F),
      "processKernelParameters invoked on non-kernel CallGraphNode");

  AttributeList AttrVec;
  const AttributeList &PAL = F->getAttributes();

  // Maps each supported byval aggregate argument of F to its primitive
  // element types and their byte offsets.
  ArgLinearization ArgsLin = GenerateArgsLinearizationInfo(*F);

  // Determine the new argument list
  SmallVector<Type *, 8> ArgTys;

  // First transfer all the explicit arguments from the old kernel
  unsigned ArgIndex = 0;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgIndex) {
    ArgTys.push_back(I->getType());
    AttrBuilder ArgAttrB(Context, PAL.getParamAttrs(ArgIndex));
    AttrVec = AttrVec.addParamAttributes(Context, ArgIndex, ArgAttrB);
  }

  // Now add all the implicit arguments
  for (unsigned IID : UsedImplicits) {
    ArgTys.push_back(getIntrinRetType(Context, IID));
    // TODO: Might need to also add any attributes from the intrinsic at some
    // point
  }
  // Add types of implicit aggregates linearization
  for (const auto &ArgLin : ArgsLin) {
    for (const auto &LinTy : ArgLin.second)
      ArgTys.push_back(LinTy.Ty);
  }

  FunctionType *NFTy = FunctionType::get(F->getReturnType(), ArgTys, false);
  IGC_ASSERT_MESSAGE((NFTy != F->getFunctionType()),
                     "type out of sync, expect bool arguments)");

  // Add any function attributes
  AttrBuilder B(Context, PAL.getFnAttrs());
  AttrVec = AttrVec.addFnAttributes(Context, B);

  // Create new function body and insert into the module
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());

  LLVM_DEBUG(dbgs() << "CMImpParam: Transforming From:" << *F);
  vc::transferNameAndCCWithNewAttr(AttrVec, *F, *NF);
  F->getParent()->getFunctionList().insert(F->getIterator(), NF);
  vc::transferDISubprogram(*F, *NF);
  LLVM_DEBUG(dbgs() << "  --> To: " << *NF << "\n");

  // Now to splice the body of the old function into the new function
  IGCLLVM::splice(NF, NF->begin(), F);

  // Loop over the argument list, transferring uses of the old arguments to the
  // new arguments, also transferring over the names as well. Afterwards I2
  // points at the first implicit argument of NF.
  std::unordered_map<const Argument *, Argument *> OldToNewArg;
  Function::arg_iterator I2 = NF->arg_begin();
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++I2) {
    I->replaceAllUsesWith(I2);
    I2->takeName(I);
    OldToNewArg[&*I] = &*I2;
  }

  // Get the insertion point ready for stores to globals
  Instruction &FirstI = *NF->getEntryBlock().begin();
  llvm::SmallVector<uint32_t, 8> ImpKinds;

  // UsedImplicits is iterated in the same order as when ArgTys was built, so
  // I2 walks the implicit arguments in step with it.
  for (unsigned IID : UsedImplicits) {
    // We know that for each IID implicit we've already added an arg
    // Rename the arg to something more meaningful here
    IGC_ASSERT_MESSAGE(I2 != NF->arg_end(),
                       "fewer parameters for new function than expected");
    I2->setName(getImplicitArgName(IID));

    auto GlobalsMapIt = GlobalsMap.find(IID);
    if (GlobalsMapIt != GlobalsMap.end()) {
      GlobalVariable *GV = GlobalsMapIt->second;
      // Also insert a new store at the start of the function to the global
      // variable used for this implicit argument intrinsic if such global is
      // present. There are no globals for pseudo intrinsics and sometimes for
      // local ID when it is used only for kernel prologue.
      new StoreInst(I2, GV, /*isVolatile=*/false,
                    IGCLLVM::getCorrectAlign(GV->getAlignment()), &FirstI);
    }

    // Prepare the kinds that will go into the metadata
    ImpKinds.push_back(MapToKind(IID));

    ++I2;
  }

  // Collect arguments linearization to store as metadata. ArgsLin is iterated
  // in the same order as when ArgTys was built, so I2 now walks the
  // linearization arguments.
  vc::ArgToImplicitLinearization LinearizedArgs;
  for (const auto &ArgLin : ArgsLin) {
    Argument *ExplicitArg = OldToNewArg[ArgLin.first];
    vc::LinearizedArgInfo &LinearizedArg = LinearizedArgs[ExplicitArg];
    for (const auto &LinTy : ArgLin.second) {
      // Name the arg after the explicit argument and the element byte offset.
      I2->setName("__arg_lin_" + ExplicitArg->getName() + "." +
                  std::to_string(LinTy.Offset));
      ImpKinds.push_back(static_cast<vc::KernelMetadata::ImpValue>(
                             vc::KernelMetadata::AK_NORMAL) |
                         vc::KernelMetadata::IMP_OCL_LINEARIZATION);
      auto &Ctx = F->getContext();
      auto *I32Ty = Type::getInt32Ty(Ctx);
      ConstantInt *Offset = ConstantInt::get(I32Ty, LinTy.Offset);
      LinearizedArg.push_back({&*I2, Offset});
      ++I2;
    }
  }

  // With LLVM >= 16 CG is presumably the call graph this object was
  // constructed with (see CMImpParamPass::run) - the legacy pass manager
  // path queries it here instead.
#if LLVM_VERSION_MAJOR < 16
  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
#endif
  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);

  if (F->hasDLLExportStorageClass())
    NF->setDLLStorageClass(F->getDLLStorageClass());

  vc::replaceFunctionRefMD(*F, *NF);

  // Explicit arguments that have linearization info are re-tagged as
  // IMP_OCL_BYVALSVM; other explicit arguments keep their kind. The kinds of
  // the implicit and linearization arguments collected above follow.
  SmallVector<unsigned, 8> ArgKinds;
  vc::KernelMetadata KM{NF};
  // Update arg kinds for the NF.
  for (unsigned i = 0; i < KM.getNumArgs(); ++i) {
    if (LinearizedArgs.count(IGCLLVM::getArg(*NF, i)))
      ArgKinds.push_back(static_cast<vc::KernelMetadata::ImpValue>(
                             vc::KernelMetadata::AK_NORMAL) |
                         vc::KernelMetadata::IMP_OCL_BYVALSVM);
    else
      ArgKinds.push_back(KM.getArgKind(i));
  }
  std::copy(ImpKinds.begin(), ImpKinds.end(), std::back_inserter(ArgKinds));
  KM.updateArgKindsMD(std::move(ArgKinds));
  KM.updateLinearizationMD(std::move(LinearizedArgs));

  // replaceAllUsesWith requires the old and new values to have the same type.
  F->mutateType(NF->getType());
  F->replaceAllUsesWith(NF);

  // Now that the old function is dead, delete it. If there is a dangling
  // reference to the CallGraphNode, just leave the dead function around
  NF_CGN->stealCalledFunctionsFrom(CG[F]);
  CallGraphNode *CGN = CG[F];
  if (CGN->getNumReferences() == 0)
    delete CG.removeFunctionFromModule(CGN);
  else
    F->setLinkage(Function::ExternalLinkage);

  return NF_CGN;
}

// Legacy pass manager identification and registration of CMImpParam. The pass
// declares a dependency on the call graph (CallGraphWrapperPass).
char CMImpParam::ID = 0;
INITIALIZE_PASS_BEGIN(CMImpParam, "CMImpParam",
                      "Transformations required to support implicit arguments",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(CMImpParam, "CMImpParam",
                    "Transformations required to support implicit arguments",
                    false, false)

#if LLVM_VERSION_MAJOR < 16
namespace llvm {
// Creates the legacy pass manager instance of CMImpParam.
// \p HasPayloadInMemory is forwarded to the CMImpParam constructor.
Pass *createCMImpParamPass(bool HasPayloadInMemory) {
  return new CMImpParam{HasPayloadInMemory};
}
} // namespace llvm

#else // LLVM_VERSION_MAJOR >= 16
// New pass manager entry point: runs CMImpParam over \p M with the call graph
// obtained from \p AM. Reports no preserved analyses when runOnModule reports
// a change, and all analyses preserved otherwise.
PreservedAnalyses CMImpParamPass::run(llvm::Module &M,
                                      llvm::AnalysisManager<llvm::Module> &AM) {
  auto &CG = AM.getResult<CallGraphAnalysis>(M);
  CMImpParam CMIP(CG, HasPayloadInMemory);
  if (CMIP.runOnModule(M))
    return PreservedAnalyses::none();
  return PreservedAnalyses::all();
}

// HasPayloadInMemory is initialized from PayloadInMemoryOpt.
CMImpParamPass::CMImpParamPass() : HasPayloadInMemory{PayloadInMemoryOpt} {}
#endif // LLVM_VERSION_MAJOR < 16
